{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fractalai.model import RandomDiscreteModel\n",
    "from fractalai.environment import ExternalProcess, ParallelEnvironment, AtariEnvironment\n",
    "from fractalai.swarm_wave import SwarmWave\n",
    "from fractalai.dataset_creator import DataGenerator\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SwarmWave?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# This algorithm is made for solving discrete Markov decision processes when a perfect model is available"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this context, if you need to generate data to train a neural network, IO is now your bottleneck."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "name = \"MsPacman-ram-v0\"\n",
    "skip_frames = 0  # Number of frames to skip at the beginning\n",
    "dt_mean = 5  # Apply the same action n times on average\n",
    "dt_std = 3 # Repeat same action a variable number of times\n",
    "min_dt = 2 # Minimum number of consecutive steps to be taken\n",
    "n_samples = 500000  # Maximum number of samples allowed\n",
    "reward_limit = 20000  # Stop the sampling when this score is reached\n",
    "n_walkers = 60  # Maximum width of the tree containing all the trajectories\n",
    "render_every = 1  # print statistics every n iterations.\n",
    "balance = 2  # Balance exploration vs exploitation\n",
    "save_data = True # Save the data generated"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/gym/__init__.py:22: UserWarning: DEPRECATION WARNING: to improve load times, gym no longer automatically loads gym.spaces. Please run \"import gym.spaces\" to load gym.spaces on your own. This warning will turn into an error in a future version of gym.\n",
      "  warnings.warn('DEPRECATION WARNING: to improve load times, gym no longer automatically loads gym.spaces. Please run \"import gym.spaces\" to load gym.spaces on your own. This warning will turn into an error in a future version of gym.')\n"
     ]
    }
   ],
   "source": [
    "env = ParallelEnvironment(name=name,env_class=AtariEnvironment,\n",
    "                          blocking=False, n_workers=8, min_dt=1)  # We will play an Atari game\n",
    "model = RandomDiscreteModel(max_wakers=n_walkers,\n",
    "                            n_actions=env.n_actions) # The Agent will take discrete actions at random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "wave = SwarmWave(model=model, env=env, n_walkers=n_walkers, reward_limit=reward_limit, samples_limit=n_samples,\n",
    "                  render_every=render_every, balance=balance, save_data=save_data,\n",
    "                  dt_mean=dt_mean, dt_std=dt_std, accumulate_rewards=True, min_dt=min_dt,\n",
    "                  prune_tree=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Environment: MsPacman-ram-v0 | Walkers: 60 | Deaths: 0 | data_size 100\n",
      "Total samples: 243740 Progress: 100.50%\n",
      "Reward: mean 19912.83 | Dispersion: 230.00 | max 20101.00 | min 19871.00 | std 44.03\n",
      "Episode length: mean 4511.25 | Dispersion 26.00 | max 4527.00 | min 4501.00 | std 6.22\n",
      "Dt: mean 4.87 | Dispersion: 11.00 | max 13.00 | min 2.00 | std 2.68\n",
      "Status: Score limit reached.\n",
      "Efficiency 0.44%\n",
      "Generated 1068 Examples | 228.22 samples per example.\n",
      "\n",
      "CPU times: user 6.27 s, sys: 345 ms, total: 6.61 s\n",
      "Wall time: 31.2 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "wave.run_swarm(print_swarm=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Increase sleep to make game render slower\n",
    "wave.render_game(sleep=0.003)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
